from batchedGeneration import TransformersInterface
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle as pkl

# Checkpoint identifier shared by the tokenizer and the model load below.
model_name = 'allenai/OLMo-2-1124-7B-DPO'  # Ensure this is the correct model name

# Tokenizer for the checkpoint. trust_remote_code=True permits code shipped
# with the checkpoint to run; the access token is passed as an empty string.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    token="",
)

# Padding token id as reported by the tokenizer (may be None if the tokenizer
# defines no pad token). Not referenced in the visible part of this script.
pad_token_id = tokenizer.pad_token_id

# Causal-LM weights for the same checkpoint; device_map="auto" delegates
# device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    token="",
    device_map="auto",
)

# Set of token strings known to the tokenizer (the keys of its vocab mapping).
vocabulary = set(tokenizer.get_vocab())

# Read the prompts, one per line, with trailing whitespace (including the
# newline) removed. errors="ignore" drops undecodable bytes instead of raising.
# The context manager closes the file as soon as the prompts are in memory.
with open("prompts.txt", "r", errors="ignore") as f:
    lines = [line.rstrip() for line in f]

interface = TransformersInterface(model, vocabulary, tokenizer)

# Both output files are opened before generation starts (same order as the
# original script), so an unwritable path fails before the generation call
# rather than after it. The `with` block guarantees both files are flushed and
# closed even if generation or a write raises.
with open("normal_olmo_raw.txt", "w") as normal_output_file, \
        open("perturbed_olmo_raw.txt", "w") as perturbed_output_file:
    # generate_response returns two sequences: the normal responses and the
    # perturbed responses. Presumably they are parallel, one entry per prompt;
    # verify against TransformersInterface.
    normal_responses, perturbed_responses = interface.generate_response(
        lines, max_new_tokens=512, batch_size=1, perplexity_outfile="raw.txt"
    )

    # Each response is written on its own line as the repr of a one-element
    # list. perturbed_responses is indexed (not zipped) so that a length
    # mismatch still raises IndexError instead of being silently truncated.
    for i, normal_response in enumerate(normal_responses):
        normal_output_file.write(str([normal_response]) + "\n")
        perturbed_output_file.write(str([perturbed_responses[i]]) + "\n")
        print("Processed:", i)